mmdetection源码详细解读 您所在的位置:网站首页 show result pyplot mmdetection源码详细解读

mmdetection源码详细解读

#mmdetection源码详细解读| 来源: 网络整理| 查看: 265

目录 简介1. 测试代码2. mmdetection/mmdet/apis2.1 mmdetection/mmdet/apis/inference.py2.1.1 init_detector2.1.2 inference_detector 3. build_detector()

简介 GitHub地址:https://github.com/open-mmlab/mmdetection.各模型的权重可以在model_zoo.md上下载。mmdetection官方使用教程https://mmdetection.readthedocs.io/en/latest/(强烈建议) 1. 测试代码 import mmcv from mmdet.apis import init_detector, inference_detector, show_result_pyplot config_file = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py' checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth' # build the model from a config file and a checkpoint file model = init_detector(config_file, checkpoint_file, device='cuda:1') # test a single image and show the results img = 'demo.jpg' # or img = mmcv.imread(img), which will only load it once result = inference_detector(model, img) # visualize the results in a new window show_result_pyplot(model, img, result)

新建一个工程,其目录结构为: 在这里插入图片描述 实际上,我把mmdetection看成是第三方库(类比numpy、cv2等),所以我把下载下来的GitHub代码解压到了python的工程文件中,即 在这里插入图片描述 但我看有的人直接把mmdetection当做是工程文件也是没问题的,比如D2Det的官方代码就是如此:你可以发现其目录结构和mmdetection的源码结构是一样的。 在这里插入图片描述

2. mmdetection/mmdet/apis

该文件夹下有三个文件:

inference.py:用于初始化模型、前向推理、读取图片、显示检测结果等。train:用于训练。test:用于测试。 在这里插入图片描述 2.1 mmdetection/mmdet/apis/inference.py 2.1.1 init_detector def init_detector(config, checkpoint=None, device='cuda:0'): """Initialize a detector from config file. Args: config (str or :obj:`mmcv.Config`): Config file path or the config object. checkpoint (str, optional): Checkpoint path. If left as None, the model will not load any weights. Returns: nn.Module: The constructed detector. """ if isinstance(config, str): # 如果config是配置文件路径,则用mmcv.Config.fromfile()读取, # 并返回类mmcv.Config()实例化后的object # 如果config已经是类mmcv.Config()实例化后的object,则不需要其他操作。 config = mmcv.Config.fromfile(config) elif not isinstance(config, mmcv.Config): raise TypeError('config must be a filename or Config object, ' f'but got {type(config)}') config.model.pretrained = None model = build_detector(config.model, test_cfg=config.test_cfg) if checkpoint is not None: map_loc = 'cpu' if device == 'cpu' else None checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc) if 'CLASSES' in checkpoint['meta']: model.CLASSES = checkpoint['meta']['CLASSES'] else: warnings.simplefilter('once') warnings.warn('Class names are not saved in the checkpoint\'s ' 'meta data, use COCO classes by default.') model.CLASSES = get_classes('coco') model.cfg = config # save the config in the model for convenience model.to(device) model.eval() return model

函数作用、输入参数、输出参数直接见注解。这里给出几个例子:

from mmdet.apis import init_detector, inference_detector, show_result_pyplot config_file = 'configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py' checkpoint_file = 'checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth' # build the model from a config file and a checkpoint file model = init_detector(config_file, checkpoint_file, device='cuda:0') # device 可以是'cpu'、'cuda:0'、'cuda:1'等。 2.1.2 inference_detector def inference_detector(model, img): """Inference image(s) with the detector. Args: model (nn.Module): The loaded detector. imgs (str/ndarray or list[str/ndarray]): Either image files or loaded images. Returns: If imgs is a str, a generator will be returned, otherwise return the detection results directly. """ cfg = model.cfg device = next(model.parameters()).device # model device # build the data pipeline test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] test_pipeline = Compose(test_pipeline) # prepare data data = dict(img=img) data = test_pipeline(data) data = collate([data], samples_per_gpu=1) if next(model.parameters()).is_cuda: # scatter to specified GPU data = scatter(data, [device])[0] else: # Use torchvision ops for CPU mode instead for m in model.modules(): if isinstance(m, (RoIPool, RoIAlign)): if not m.aligned: # aligned=False is not implemented on CPU # set use_torchvision on-the-fly m.use_torchvision = True warnings.warn('We set use_torchvision=True in CPU mode.') # just get the actual data from DataContainer data['img_metas'] = data['img_metas'][0].data # forward the model with torch.no_grad(): result = model(return_loss=False, rescale=True, **data) return result 3. build_detector()

model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

根据配置文件构建神经网络

以configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py配置文件为例,cfg.model,cfg.train_cfg,cfg.test_cfg均为字典类型,分别与配置文件中的内容相对应。

DETECTORS = Registry('detector') def build_detector(cfg, train_cfg=None, test_cfg=None): """Build detector.""" return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) Registry是一个大工厂,工厂中包含了很多小仓库,共有7个小仓库,这些仓库在train.py开始import 模块时就自动创建,并且每个仓库中都放了各种样的小商品,比如在detectors/__init__.py中可以看到,仓库DETECTORS的小商品有:。 BACKBONES = Registry('backbone') NECKS = Registry('neck') ROI_EXTRACTORS = Registry('roi_extractor') SHARED_HEADS = Registry('shared_head') HEADS = Registry('head') LOSSES = Registry('loss') DETECTORS = Registry('detector') # detectors/__init__.py from .atss import ATSS from .base import BaseDetector from .cascade_rcnn import CascadeRCNN from .fast_rcnn import FastRCNN from .faster_rcnn import FasterRCNN from .fcos import FCOS from .fovea import FOVEA from .fsaf import FSAF from .gfl import GFL from .grid_rcnn import GridRCNN from .htc import HybridTaskCascade from .mask_rcnn import MaskRCNN from .mask_scoring_rcnn import MaskScoringRCNN from .nasfcos import NASFCOS from .point_rend import PointRend from .reppoints_detector import RepPointsDetector from .retinanet import RetinaNet from .rpn import RPN from .single_stage import SingleStageDetector from .two_stage import TwoStageDetector 此时参数cfg是字典类型,执行build_from_cfg() def build(cfg, registry, default_args=None): if isinstance(cfg, list): modules = [build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg] return nn.Sequential(*modules) else: return build_from_cfg(cfg, registry, default_args) 比如在fcos.py文件中,修饰器@DETECTORS.register_module()的作用就是将创建好的小商品FCOS放入仓库DETECTORS中。 @DETECTORS.register_module() class FCOS(SingleStageDetector): """Implementation of `FCOS `_""" def __init__(self, backbone, neck, bbox_head, train_cfg=None, test_cfg=None, pretrained=None): super(FCOS, self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg, pretrained) 在register_module()函数中,你可以看到它的用法,小商品的名字就是在注册时指定的,而注册时的名字就是类名FCOS或ResNet等。所以你要创建自己的检测器,你需要构建一个类,把这个类当做小商品,然后将其放在相应的仓库中(注册)。 def register_module(self, name=None, force=False, module=None): """Register a module. A record will be added to `self._module_dict`, whose key is the class name or the specified name, and value is the class itself. It can be used as a decorator or a normal function. Example: >>> backbones = Registry('backbone') >>> @backbones.register_module() >>> class ResNet: >>> pass >>> backbones = Registry('backbone') >>> @backbones.register_module(name='mnet') >>> class MobileNet: >>> pass >>> backbones = Registry('backbone') >>> class ResNet: >>> pass >>> backbones.register_module(ResNet) obj_type='FCOS',在 def build_from_cfg(cfg, registry, default_args=None): """Build a module from config dict.""" args = cfg.copy() obj_type = args.pop('type') # 'FCOS' if is_str(obj_type): obj_cls = registry.get(obj_type) if obj_cls is None: raise KeyError( f'{obj_type} is not in the {registry.name} registry') elif inspect.isclass(obj_type): obj_cls = obj_type else: raise TypeError( f'type must be a str or valid type, but got {type(obj_type)}') if default_args is not None: for name, value in default_args.items(): args.setdefault(name, value) return obj_cls(**args)


【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有